import os
import sys
import torch
import numpy as np
import math
from imageio import imread
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

sys.path.append(os.path.join(os.getcwd()))  # HACK add the root folder
from .projection import ProjectionHelper

# Scale factor applied to the intrinsics and to the projection/depth frame size.
CROP_RATE = 2

# 4x4 intrinsic matrix laid out as [[fx, 0, cx, 0], [0, fy, cy, 0], [0, 0, 1, 0],
# [0, 0, 0, 1]] (presumably the usual pinhole convention — confirm), with the
# focal/principal-point entries scaled by CROP_RATE.
INTRINSICS = [
    [37.01983 * CROP_RATE, 0, 20 * CROP_RATE, 0],
    [0, 38.52470 * CROP_RATE, 15.5 * CROP_RATE, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
]

# Module-level projector used by compute_projection. The size argument
# [41 * CROP_RATE, 32 * CROP_RATE] matches the depth frames loaded in
# projection_interface. NOTE(review): the other positional arguments
# (0.1, 8.0, 0.5) presumably are depth_min, depth_max and accuracy — confirm
# against ProjectionHelper's signature.
# Earlier configuration: ProjectionHelper(INTRINSICS, 0.1, 4.0, [41, 32], 0.05)
PROJECTOR = ProjectionHelper(INTRINSICS, 0.1, 8.0, [41 * CROP_RATE, 32 * CROP_RATE], 0.5)


def to_tensor(arr):
    """Build a default-dtype ``torch.Tensor`` from ``arr`` and place it on the current CUDA device."""
    host_tensor = torch.Tensor(arr)
    return host_tensor.cuda()


def resize_crop_image(image, new_image_dims):
    """Scale ``image`` to the target height, then center-crop it to the target width.

    Args:
        image: numpy array of shape (H, W) or (H, W, C).
        new_image_dims: target size given as ``[width, height]`` (compared as a
            list against the image's own ``[width, height]``).

    Returns:
        ``image`` itself when it already has the target size; otherwise a new
        numpy array of the target size. Resizing uses nearest-neighbour
        interpolation, so label/depth values are not blended.
    """
    src_height, src_width = image.shape[0], image.shape[1]
    if [src_width, src_height] == new_image_dims:
        return image

    dst_width = new_image_dims[0]
    dst_height = new_image_dims[1]
    # Keep the aspect ratio: match the target height and round the width down.
    scaled_width = int(math.floor(dst_height * float(src_width) / float(src_height)))

    resize = transforms.Resize([dst_height, scaled_width], interpolation=InterpolationMode.NEAREST)
    resized = resize(Image.fromarray(image))
    crop = transforms.CenterCrop([dst_height, dst_width])
    return np.array(crop(resized))


def load_image(file, image_dims):
    """Read an image, resize/crop it to ``image_dims`` and normalize color images.

    Args:
        file: path (or other source accepted by ``imread``).
        image_dims: target size as ``[width, height]`` (see ``resize_crop_image``).

    Returns:
        - For a 3-D (H, W, C) color image: a float ``torch.Tensor`` of shape
          (C, H, W). It is scaled to [0, 1] and normalized with the hard-coded
          per-channel mean/std below.
        - For a 2-D label image: the resized numpy array, otherwise untouched.

    Raises:
        ValueError: if the decoded image is neither 2-D nor 3-D.
    """
    image = imread(file)
    image = resize_crop_image(image, image_dims)

    if image.ndim == 2:  # label image: no normalization
        return image
    if image.ndim != 3:
        # The original used a bare `raise` here, which fails with
        # "RuntimeError: No active exception to reraise" instead of a real error.
        raise ValueError(
            f"unsupported image shape {image.shape} from {file!r}; expected 2-D or 3-D"
        )

    image = np.transpose(image, [2, 0, 1])  # (H, W, C) -> (C, H, W)
    normalize = transforms.Normalize(mean=[0.496342, 0.466664, 0.440796], std=[0.277856, 0.28623, 0.291129])
    return normalize(torch.Tensor(image.astype(np.float32) / 255.0))


def load_pose(filename):
    """Load a 4x4 pose matrix from a whitespace-separated text file.

    The file must contain exactly four non-blank lines. Each line needs at
    least four numbers separated by any whitespace (spaces or tabs). Only the
    first four numbers of each line are used.

    Returns:
        ``np.ndarray`` of shape (4, 4) and dtype float32.

    Raises:
        ValueError: if the file does not have four non-blank lines, if a line
            has fewer than four values, or if a value is not numeric.
    """
    # `with` closes the handle; the original leaked it via open(...).read().
    with open(filename) as f:
        # str.split() with no argument tolerates repeated spaces, tabs and
        # trailing whitespace, which split(" ") turned into empty tokens.
        rows = [line.split() for line in f if line.strip()]

    if len(rows) != 4:
        raise ValueError(f"expected 4 pose rows in {filename!r}, found {len(rows)}")
    if any(len(row) < 4 for row in rows):
        raise ValueError(f"each pose row in {filename!r} needs at least 4 values")

    return np.asarray([row[:4] for row in rows]).astype(np.float32)


def load_depth(file, image_dims):
    """Read a depth image, resize/crop it to ``image_dims`` and divide it by 1000.

    Args:
        file: path (or other source accepted by ``imread``).
        image_dims: target size as ``[width, height]`` (see ``resize_crop_image``).

    Returns:
        float32 numpy array. NOTE(review): the division by 1000 presumably
        converts millimetres to metres — confirm against the depth source.
    """
    raw_depth = resize_crop_image(imread(file), image_dims)
    return raw_depth.astype(np.float32) / 1000.0


def compute_projection(points, depth, camera_to_world):
    """Project a point cloud into a stack of depth frames.

    Each frame is passed to ``PROJECTOR.compute_projection``. A non-empty
    result is a pair of index tensors:
        - ``indices[0]`` holds point indices;
        - ``indices[1]`` holds flattened pixel indices.
    The first element of each tensor is the number of valid correspondences,
    and the correspondences follow it.

    Only the last frame that yields a projection determines the return value;
    results from earlier frames are overwritten.

    :param points: point coordinates, shape (num_points, 3)
    :param depth: depth frames, shape (num_frames, H, W)
    :param camera_to_world: camera poses, shape (num_frames, 4, 4)

    :return proj_mask: CPU float tensor of shape (H, W) with 1 at every pixel
        hit by a projected point (all zeros if no frame yields a projection)
    :return valid_3d_points: int64 numpy array of point indices with a valid projection
    :return projected_pixels: int64 numpy array of the matching flattened pixel
        indices (``y * W + x``)
    """
    frame_shape = depth.shape[1:]
    # Defaults for the case where no frame yields a projection; the original
    # code raised UnboundLocalError at the return statement in that case.
    proj_mask = torch.zeros(frame_shape)
    valid_3d_points = np.zeros(0, dtype=np.int64)
    projected_pixels = np.zeros(0, dtype=np.int64)

    # Uploaded once instead of per frame; assumes PROJECTOR does not modify
    # its inputs in place.
    points_gpu = to_tensor(points)
    for i in range(depth.shape[0]):
        indices = PROJECTOR.compute_projection(points_gpu, to_tensor(depth[i]), to_tensor(camera_to_world[i]))
        if not indices:
            continue

        indices_3d = indices[0].long()
        indices_2d = indices[1].long()
        num_valid_points = indices_2d[0].item()
        projected_pixels = indices_2d[1:num_valid_points + 1].cpu().numpy()
        valid_3d_points = indices_3d[1:num_valid_points + 1].cpu().numpy()

        # Un-flatten pixel indices into (row, column) with the frame width.
        y_coords, x_coords = np.divmod(projected_pixels, depth[i].shape[1])
        proj_mask = torch.zeros(frame_shape)
        proj_mask[y_coords, x_coords] = 1

    return proj_mask, valid_3d_points, projected_pixels


def projection_interface(scene_points_path, scene_image_path, scene_depht_path, scene_pose_path):
    """Project a saved point cloud into a single RGB-D frame.

    Args:
        scene_points_path: ``.npy`` file; the first three columns are used as
            point coordinates.
        scene_image_path: color image, loaded at ``[328, 256] * CROP_RATE``.
        scene_depht_path: depth image, loaded at ``[41, 32] * CROP_RATE``.
        scene_pose_path: text file holding the 4x4 camera-to-world pose.

    Returns:
        ``(mask, point_indices, pixel_indices)`` from ``compute_projection``,
        with the mask converted to a numpy array.
    """
    # Process-wide side effects. NOTE(review): these presumably only take
    # effect if CUDA has not been initialized yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    cloud = np.load(scene_points_path)
    xyz = cloud[:, :3]

    image_w, image_h = 328 * CROP_RATE, 256 * CROP_RATE
    depth_w, depth_h = 41 * CROP_RATE, 32 * CROP_RATE

    # Single-frame batches, because compute_projection iterates over frames.
    frame_images = np.zeros((1, 3, image_h, image_w))
    frame_depths = np.zeros((1, depth_h, depth_w))
    frame_poses = np.zeros((1, 4, 4))

    # NOTE(review): the color image is loaded but not used below.
    frame_images[0] = load_image(scene_image_path, [image_w, image_h])
    frame_depths[0] = load_depth(scene_depht_path, [depth_w, depth_h])
    frame_poses[0] = load_pose(scene_pose_path)

    mask, point_idx, pixel_idx = compute_projection(xyz, frame_depths, frame_poses)
    return mask.numpy(), point_idx, pixel_idx